{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"/home/maciej/repos/scikit-learn\")\n",
    "\n",
    "import cv2\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from dataset_io import *\n",
    "import numpy as np\n",
    "from joblib import dump, load\n",
    "import timeit\n",
    "import random\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient features - RandomForest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pallet_X=np.load(PROJECT_PATH+\"pallet_rectangles_gradient.npy\")\n",
    "# background_X=np.load(PROJECT_PATH+\"background_rectangles_gradient.npy\")\n",
    "\n",
    "# pallet_X=[np.dstack(x) for x in pallet_X]\n",
    "# background_X=[np.dstack(x) for x in background_X]\n",
    "\n",
    "# print(np.array(pallet_X).shape)\n",
    "# print(np.array(background_X).shape)\n",
    "\n",
    "# pallet_X=[x.flatten() for x in pallet_X]\n",
    "# background_X=[x.flatten() for x in background_X]\n",
    "\n",
    "# pallet_y=np.ones(len(pallet_X))\n",
    "# background_y=np.zeros(len(background_X))\n",
    "\n",
    "# X=np.vstack((pallet_X,background_X))\n",
    "# y=np.hstack((pallet_y,background_y))\n",
    "\n",
    "# sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=0)\n",
    "# sss.get_n_splits(X, y)\n",
    "# train_index, test_index = list(sss.split(X, y))[0]\n",
    "# X_train, X_test = X[train_index], X[test_index]\n",
    "# y_train, y_test = y[train_index], y[test_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gradient_data(train_scenes, test_scenes):\n",
    "    \"\"\"Build train/test matrices of flattened gradient features.\n",
    "\n",
    "    Pallet samples (feature pallet_rectangles_gradient) get label 1 and\n",
    "    background samples (feature background_rectangles_gradient) get label 0.\n",
    "    The files are located with get_existing_clf_ds_filepath(feature, scene).\n",
    "\n",
    "    Args:\n",
    "        train_scenes: iterable of scenes used for the training split.\n",
    "        test_scenes: iterable of scenes used for the test split.\n",
    "\n",
    "    Returns:\n",
    "        X_train, X_test, y_train, y_test. In every output the pallet rows\n",
    "        come first, followed by the background rows.\n",
    "    \"\"\"\n",
    "    def load_flat_samples(feature, scenes, y_val):\n",
    "        # One array per scene, joined along the first axis.\n",
    "        per_scene = []\n",
    "        for scene in scenes:\n",
    "            per_scene.append(np.load(get_existing_clf_ds_filepath(feature, scene)))\n",
    "        # dstack the parts of each sample, then flatten into one feature vector.\n",
    "        vectors = [np.dstack(sample).flatten() for sample in np.concatenate(per_scene)]\n",
    "        return vectors, np.full(len(vectors), y_val)\n",
    "\n",
    "    pallet_feature = \"pallet_rectangles_gradient\"\n",
    "    background_feature = \"background_rectangles_gradient\"\n",
    "\n",
    "    # Load order: train pallets, train backgrounds, test pallets, test backgrounds.\n",
    "    pallet_train, pallet_train_y = load_flat_samples(pallet_feature, train_scenes, 1)\n",
    "    background_train, background_train_y = load_flat_samples(background_feature, train_scenes, 0)\n",
    "    pallet_test, pallet_test_y = load_flat_samples(pallet_feature, test_scenes, 1)\n",
    "    background_test, background_test_y = load_flat_samples(background_feature, test_scenes, 0)\n",
    "\n",
    "    X_train = np.vstack((pallet_train, background_train))\n",
    "    X_test = np.vstack((pallet_test, background_test))\n",
    "    y_train = np.hstack((pallet_train_y, background_train_y))\n",
    "    y_test = np.hstack((pallet_test_y, background_test_y))\n",
    "\n",
    "    return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc:  0.945985401459854  f1:  0.8438818565400844  prec:  1.0  rec:  0.7299270072992701\n"
     ]
    }
   ],
   "source": [
    "def classify_gradient(train_scenes, test_scenes, fold_name, persist=False):\n",
    "    \"\"\"Train a random forest on gradient features and score it on the test scenes.\n",
    "\n",
    "    Args:\n",
    "        train_scenes: scenes whose samples are used for fitting.\n",
    "        test_scenes: scenes whose samples are used for evaluation.\n",
    "        fold_name: suffix used in the file name of the persisted model.\n",
    "        persist: when true, dump the fitted model to\n",
    "            rand_forest_clf_<fold_name>.joblib (path relative to the working directory).\n",
    "\n",
    "    Returns:\n",
    "        (acc, prec, rec, f1, clf): accuracy, precision, recall and F1 on the\n",
    "        test split, plus the fitted RandomForestClassifier.\n",
    "    \"\"\"\n",
    "    X_train, X_test, y_train, y_test = get_gradient_data(train_scenes, test_scenes)\n",
    "    # 128 trees: as fast as smaller forests and as accurate as larger ones.\n",
    "    clf = RandomForestClassifier(n_estimators=128, n_jobs=-1)\n",
    "    clf.fit(X_train, y_train)\n",
    "    y_pred = clf.predict(X_test)\n",
    "\n",
    "    acc = accuracy_score(y_test, y_pred)\n",
    "    prec = precision_score(y_test, y_pred)\n",
    "    rec = recall_score(y_test, y_pred)\n",
    "    f1 = f1_score(y_test, y_pred)\n",
    "\n",
    "    if persist:\n",
    "        dump(clf, \"rand_forest_clf_{}.joblib\".format(fold_name))\n",
    "    return acc, prec, rec, f1, clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[548   0]\n",
      " [ 37 100]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x7fbc3855d240>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAAEGCAYAAAAjX4PvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dfVxUZf7/8dcB8hYxMBlASUVRWSXJLGAtLdgR8wYB4UtZftVy09bNH2rmXW03oljhTdq6LVlq3pSliIoWKKiYpaZmqNmNIiYqYCByoyIM8/uDr7OxJkw6w5yZ+Tx7zMYczpzzgcfy7rquc851KXq9Xo8QQqiEg6ULEEKI35JQEkKoioSSEEJVJJSEEKoioSSEUBUJJSGEqjhZugAhhHktXbmJ1s7NjNr3oft88PX1NXNF9ZNQEsLGtXZuxtg5W43a9/Ca581cTcMklISwB4pi6QqMJqEkhK1TFHBwtHQVRpNQEsIeKNZzTUtCSQibp0j3TQihMtJSEkKohoK0lIQQaqJIS0kIoTJy9U0IoR4y0C2EUBMF6b4JIVRGQkkIoR4KOEj3TQihFgoy0C2EUBPT3hIQEhJCy5YtcXBwwNHRkeTkZEpKSpg0aRLnzp2jXbt2LFq0iNatW6PX65kzZw67d++mWbNmzJs3jx49etR7fOvpaAohbp+iGPcy0sqVK9m0aRPJyckAJCUlERwcTHp6OsHBwSQlJQGQlZVFbm4u6enpzJ49m9dee63BY0soCWEPFAfjXrcpIyODiIgIACIiItixY0ed7YqiEBAQQGlpKYWFhfUeS0JJCHtg4pbSs88+S1RUFOvWrQOgqKgId3d3ANq2bUtRUREABQUFeHh4GD7n4eFBQUFBvce2uzGljF17yS8qt3QZNi3A715Ll2Dzqq5XEhAQYNzOivFjSsXFxURFRRnex8bGEhsbW2efjz/+GI1GQ1FREWPGjMHHx+e/Tqeg3MHNmnYXSvlF5YyN32LpMmzapW/etXQJNu/UTyf+2AeMvPrm5uZmGCe6FY1GA0CbNm3QarVkZ2fTpk0bCgsLcXd3p7CwEDc3N8O++fn5hs/m5+cbPn/LUo2qVAhhxYzsuhnRurly5Qrl5eWGr/fu3Yuvry8hISGkpKQAkJKSQmhoKIBhu16v58iRI7Rq1crQzbsVu2spCWGXTHRLQFFRERMmTABAp9MxZMgQ+vXrh7+/P3Fxcaxfvx4vLy8WLVoEQP/+/dm9ezdarZbmzZszd+7cBs8hoSSErTPhs2/e3t5s3rz5pu2urq6sXLny5lMrCq+++uofOoeEkhA2T2YJEEKojTyQK4RQD1liSQihJjJHtxBCbe7kZsbGJqEkhB2QUBJCqIv1ZJKEkhC2TlEUHBzk6psQQkWk+yaEUBUJJSGEulhPJkkoCWEPpKUkhFAPRUJJCKEiCnL1TQihNtbTUJJQEsIeSPdNCKEeMqYkhFAbCSUhhLpYTyZJKAlh6+TqmxBCXWRMSQihNhJKQgh1sZ5MklASwh5IS0kIoRoyyZsQQn2sp6EkoSSEPZDumxBCVSSUhBCqIqEkhFANRVEklIQQ6qI4WE8oWc91QiHEbbvRWmroZSydTkdERATjxo0D4OzZs8TExKDVaomLi+P69esAXL9+nbi4OLRaLTExMeTl5TV4bAklIeyAohj3MtZHH31E586dDe8TExMZPXo027dvx8XFhfXr1wPw2Wef4eLiwvbt2xk9ejSJiYkNHltCSQhbp5i2pZSfn8+uXbuIjo4GQK/Xs2/fPsLCwgCIjIwkIyMDgMzMTCIjIwEICwvj66+/Rq/X13t8GVMSwsYpGN8KKi4uJioqyvA+NjaW2NjYOvvMnTuXqVOnUlFRAcClS5dwcXHByak2Tjw8PCgoKACgoKAAT09PAJycnGjVqhWXLl3Czc3tljVIKKnAD1tfp6yiEl1NDdW6Gh5+6i3D9/7fyBDm
TY6i/WPTKCqpwMW5GR/Gj8Lb0xUnR0cWfZTBqs37LFi9dUtP+4IXJ/8/dDodo58Zy9SXplu6JLNwMHKg283NjeTk5Ft+f+fOnbi5udGzZ0/2799vqvLqUE0oTZ8+nUcffZSBAwfecp/k5GT69u2LRqNpxMoax8Dn3qGopKLOtvaauwkN8uOXC8WGbeP+px8/5OQTHfdv7nF15ruNr/DJtm+oqtY1dslWT6fTETdxAls/30679u15OOhBhgwJx+9Pf7J0aaalKEaHUkMOHz5MZmYmWVlZVFZWUl5ezpw5cygtLaW6uhonJyfy8/MNf6MajYYLFy7g4eFBdXU1ZWVluLq61nsOqxpT2rhxI4WFhZYuo9G89eJwZr2TUqcPrgecWzYFoGXzply6fIVqXY2FKrRu3xw4QOfOXejk40OTJk2IiX2C1C2bLF2Wyd3ovplioHvKlClkZWWRmZnJggULCAoKYv78+QQGBpKWlgbU/p2GhIQAEBISwsaNGwFIS0sjKCiowbErs7WU8vLyGDt2LD169OD777/H19eXN998kw8++ICdO3dSWVnJ/fffzxtvvHFTkceOHWPevHlcuXIFV1dXEhISOHz4MMeOHePFF1+kWbNmrFu3jmXLljV4LGug1+vZsvTv6PV6Ptiwlw+T9zLkUX/OF5Zw9KdzdfZ975PdrF80jpz0ObRq2YyR0z5scOBQ/L7z58/Rvr234X27du05cMA8XRJLM/ffxdSpU5k0aRKLFi3Cz8+PmJgYAKKjo5k6dSparZbWrVuzcOHCBo9l1u7b6dOnmTNnDg888AAzZsxg7dq1PP300/z97383/CA7d+40pCpAVVUV8fHxLF26FDc3N7Zt28bChQtJSEhgzZo1vPTSS/j7+wM0eCxrETpmIecvXqatqzOp7/2dH3PzeemZMIb87d2b9tX+2Y/sH/MY+NxifLzvYeu//s7e2FOUVVyzQOXCWpgjkwIDAwkMDATA29vbcBvAbzVt2pTFixf/oeOaNZQ8PT154IEHAAgPD2fVqlW0b9+eZcuWce3aNUpKSvD19a0TJKdPn+ann35izJgxANTU1NC2bdvfPf7+/fvrPZa1OH/xMgAXL5WzOTObRx7wpUO7NhxYNwOAdu538/XaaTwy8m1Ghgcxf/l2AHLO/kruuSK6ddRw8PgZi9Vvrby82pGXd9bw/ty5PNq1a2fBisxE5uj+j//+RSiKwuuvv86GDRvw9PRkyZIlVFZW1tlHr9fj6+vLunXr6j12ZWVlg8eyBi2aNcHBQaH8SiUtmjXhL8HdmZv0OR1CZxj2+WHr6/R96i2KSio4m3+JRx/qxt5vT+Hu1oquHTWcPverBX8C69XnwQc5efJnck+fxqtdOz5b9wkrVq21dFkmp2D81Tc1MOtA9/nz5/n2228BSE1NNbSaXF1dqaioMAyM/VanTp0oLi42fK6qqoqff/4ZgJYtWxrujbgRQPUdyxq4t2lFxvJJ7F83nT2rp/L5nuNs/+rELfef9/4XBPXqxDefzmTbv19g1jubbrpqJ4zj5OTEwnfeZejgMAL8/Rge8z/8qUcPS5dlFqZ+zMSczNpS6tSpE2vWrGHmzJl06dKFJ598ksuXLzNkyBDuuecew9jQbzVp0oTFixcTHx9PWVkZOp2OUaNG4evrS2RkJK+++qphoDsmJqbeY1mD3HNFBMbOq3ef7oNfNXx94eJlhv7tn+Yuy24MfHwQAx8fZOkyzE4leWMURW+mSzd5eXmMHz+e1NRUcxz+tq3ZkMbY+C2WLsOmXfrm5gF6YVqnfjpBjz/5GbXvhswDzN1rXGt6dbQHfn7GHddcVHPzpBDCPP7IYyZqYLZQat++vepaSULYK7WMFxlDWkpC2AFruvomoSSErfuDcyVZmoSSEDaudkzJelJJQkkIO2BFmSShJIQ9kJaSEEI9TDifUmOQUBLCxsmYkhBCdawokySUhLAH0lISQqiKFWWShJIQNk8meRNCqIkCOMrV
NyGEmlhRQ0lCSQjbp55ZJY0hoSSEjVMUsKLem4SSEPZAWkpCCFWxokySUBLC1imAoxWlkoSSEHZAum9CCPWwxZknc3NzWbBgASdPnqyzCm1GRobZChNCmIaCYlVX34xaIXfGjBk8+eSTODo68tFHHxEREUF4eLi5axNCmIiiGPdSA6NCqbKykuDgYADatWvHCy+8wO7du81amBDCdBwcFKNeamBU961JkybU1NTQoUMHVq9ejUajoaJC1q8XwhrU3jypjsAxhlEtpZkzZ3L16lVefvlljh8/zubNm3nrrbfMXZsQwkQUI18NqaysJDo6mvDwcAYPHszixYsBOHv2LDExMWi1WuLi4rh+/ToA169fJy4uDq1WS0xMDHl5eQ2ew6hQOnfuHC1btsTDw4OEhASWLFnC+fPnjfmoEEIFFEUx6tWQJk2asHLlSjZv3kxKSgp79uzhyJEjJCYmMnr0aLZv346Liwvr168H4LPPPsPFxYXt27czevRoEhMTGzyHUaGUlJRk1DYhhDo5KMa9GqIoCi1btgSgurqa6upqFEVh3759hIWFARAZGWm4Mp+ZmUlkZCQAYWFhfP311+j1+nrPUe+Y0u7du8nKyqKgoID4+HjD9vLychwdHRv+CYQQFvdHFg4oLi4mKirK8D42NpbY2Ng6++h0OqKiovjll18YMWIE3t7euLi44ORUGyceHh4UFBQAUFBQgKenJwBOTk60atWKS5cu4ebmdssa6g0ljUZDz549yczMpEePHobtLVu2ZMaMGUb9kEIIC/sDSyy5ubmRnJxc7z6Ojo5s2rSJ0tJSJkyYQE5OjimqNKg3lLp370737t0ZMmQIOp2O8+fP4+PjY9IChBDmZ46r/S4uLgQGBnLkyBFKS0uprq7GycmJ/Px8NBoNUNuwuXDhAh4eHlRXV1NWVoarq2v9tRpz8j179jBs2DDGjh0LwIkTJxg/fvwd/khCiMZwo/tmioHu4uJiSktLAbh27RpfffUVnTt3JjAwkLS0NAA2btxISEgIACEhIWzcuBGAtLQ0goKCGjyPUfcpvfvuu6xfv56RI0cC4Ofnx7lz54z5qBBCBUzVUCosLGT69OnodDr0ej0DBw7kscceo0uXLkyaNIlFixbh5+dHTEwMANHR0UydOhWtVkvr1q1ZuHBhg+cwKpRuDFAJIayTqW6e7N69OykpKTdt9/b2NtwG8FtNmzY13MtkLKNCqUuXLmzZsgWdTkdubi6rVq3i/vvv/0MnEkJYjhXd0G3cmNIrr7zCyZMnadKkCZMnT8bZ2ZlZs2aZuzYhhAkoig0++9a8eXMmTZrEpEmTzF2PEMIMrOnZt3pDac6cOcyaNeuWV9ree+89sxQlhDAtK8qk+kNp2LBhADzzzDONUkxj8O/mzZmshq8AiNt3Mr/c0iXYvOtVOqP3VWxp3beePXsC0KJFC8PXN+zcudN8VQkhTMqowWOVMHqg+6effjK8T01NZenSpWYrSghhOooCjg6KUS81MCqUFi9ezLRp0zh16hSffvopa9eu5cMPPzR3bUIIEzHVLAGNwairb97e3ixYsIAJEybg6enJhx9+SLNmzcxdmxDCRIwfU6p/WpHGUG8oDR06tM77y5cvo9PpDLeQb9myxXyVCSFMQkE9rSBj1BtKcslfCBugopVKjFFvKLVr1w6AefPmER0dTZcuXRqlKCGEadnMzZM3dO7cmZdfftkw49yQIUPkAV0hrIQCOFpPJhkXSjExMcTExJCTk0NycjLh4eH07t2bmJgYgoKCzF2jEOIOWVNLyeh7qnQ6HTk5OeTk5ODq6kq3bt1YsWKFPA8nhMrVTvJmPSvkGtVSmjt3Lrt27SIoKIjx48dz3333Gb53YwUDIYRKqegeJGMYFUrdunUjLi6OFi1a3PS935vYSQihJortdd82b958UyCNGjUKQAa8hVA5m+q+VVZWcvXqVS5dusTly5cNi8iVl5cb1nUSQqifo1oSxwj1htInn3zC
ypUrKSwsNKxyCeDs7MzTTz9t9uKEEHfOpu7oHjVqFKNGjWLVqlWGlUyEEFbGyga66x1Tev/99wEYOXIkn3/+eZ3vLViwwHxVCSFMylTrvjWGekNp27Zthq+TkpLqfG/Pnj3mqUgIYVI3um82MXXJjYHt//76994LIdRLLRO4GaPeUPptc+6/m3ZqaeoJIeqnqKgVZIx6Q+mHH36gd+/e6PV6Kisr6d27N1DbSrp+/XqjFCiEuHPW1IaoN5ROnDjRWHUIIczIAetJJaMeMxFCWDebaSkJIayfTd08KYSwAYp6lk8yhjWtUSeEuA21LSXFqFdDLly4wMiRIxk0aBCDBw9m5cqVAJSUlDBmzBgGDBjAmDFjuHz5MlB7USw+Ph6tVsvQoUM5fvx4g+eQUBLCDphqlgBHR0emT5/Otm3bWLduHWvXruXkyZMkJSURHBxMeno6wcHBhputs7KyyM3NJT09ndmzZ/Paa681eA4JJSHsgIORr4a4u7vTo0cPoPbBfB8fHwoKCsjIyCAiIgKAiIgIduzYAWDYrigKAQEBlJaWUlhY2GCtQggbVjufkumffcvLy+PEiRP06tWLoqIi3N3dAWjbti1FRUUAFBQU4OHhYfiMh4dHg9MeyUC3ELZOAUcj71MqLi4mKirK8D42NpbY2Nib9quoqGDixInMnDkTZ2fnuqe7w4d7JZSEsAPGRoSbmxvJycn17lNVVcXEiRMZOnQoAwYMAKBNmzYUFhbi7u5OYWEhbm5uAGg0GvLz8w2fzc/PR6PR1Ht86b4JYeNMOR2uXq9n1qxZ+Pj4MGbMGMP2kJAQUlJSAEhJSSE0NLTOdr1ez5EjR2jVqpWhm3cr0lISwg6Y6gH6Q4cOsWnTJrp27cqwYcMAmDx5Ms899xxxcXGsX78eLy8vFi1aBED//v3ZvXs3Wq2W5s2bM3fu3IZr1dvZHCTZx77Ho4OvpcuwaYWXKy1dgs27XpRL7149jdp376HvyHO4x6h972tWip+f352UdsekpSSEzVPPrJLGkFASwsbduKPbWkgoCWEHrOmKloSSELZOsa6ZYiWUhLBxCsbfp6QGEkpC2AEraihZVVfT5l27do2wR//MY39+gH4P9eKtOa8DEB72GCF9+xDStw/3de3AqCeHW7hS6/LKlOfpH9CJyNCHDNsuXyrmryPCGfxIAH8dEc7lkktA7c2BCf+YyqCHexGlDeL7o0csVbZJOaAY9VID1YdSaWkpa9asAWofABwyZIiFKzKfpk2bkpyazs6vDpGx9yCZO9I5eGA/m9N2krn3IJl7D9LnoUAGD42wdKlWZVjMU/xr1cY62z5YuoDAvv3ZuucIgX3788HS2sVV9+xM58zpU2zdc4RX31xM/MxJlijZpEw5n1JjsIpQ+vjjjy1dRqNQFIWW//dwY1VVFdXVVXUGKMtKS/kyaxePDxlmqRKtUp+gh2l9t2udbTvTtzIs+ikAhkU/xc60VMP28OFPoigKvXo/RFlpCRcL8m86prUx1WMmjUH1Y0rz58/nl19+YdiwYXTo0AGAs2fP4u7uzmuvvcaxY8cME08FBQVZuNo7p9Pp0PYL5HTOKZ7563geePA/XY7PUzfxSP/HaOXiYsEKbUPRrxdpq6mdUuMedw1Fv14EoDD/PB5e7Qz7aTzbUZh/3rCvdVJP18wYqg+lKVOm8PPPP7Np0yZOnTpFZGQkhw4dori4GIAtW7Zw6tQpnn32WdLS0mjatKmFK74zjo6OZO49yOWSEkY/FcOJ74/h96faxwk2rv+Up0aNaeAI4o9S1NRMMBNr+vFU3327oaamhoSEBO69914iIiI4dOgQ4eHhAHTu3BkvLy9Onz5t4SpNp/Xdd/PwI/3ZuSMdgKKiX/n20Df8JWyQhSuzDW3uaWvoll0syKdNm9pnw9w9vMg/f86wX8GFc7h7eFmkRlMxtuumluCymlBycHAwan5fa/brrxe5XFICwNWrV9m9M4Muvt0ASE1JRjtwEM2a
NbNkiTbjUe0gNq2vvYCyaf0aHhswGIDHtIPYvOFj9Ho93x0+gHOr1lbedavlqChGvdRA9d23li1bUlFRcdP2Pn36sGXLFoKDgzl9+jQXLlzAx8fHAhWaTkH+BSaOfxadTkdNTQ3DIqMZ8HjtH0vKhk95YdJUC1donV6aMIZv9u2hpLiI0Ae7MWHKTJ6dMJkXnx/Fxk9W4dnem/lLa1fleCQkjKzMdAY93ItmzZsTP/9fFq7eNBQrGlOyiqlLpkyZwo8//oiPjw85OTmkpqZSWVl5WwPdMnWJ+cnUJeb3R6YuOfDtUSpaGNcF9agplKlLjDF//vybtjVt2pSEhAQLVCOE9bGmlpJVhJIQ4s5Y0QK5EkpC2ANpKQkhVKP2MRNLV2E8CSUhbJ2KnmszhoSSEHbAeiJJQkkI+2BFqSShJISNq5150npSSUJJCDtgRUNKEkpC2AMJJSGEqkj3TQihGorhf6yDhJIQdsCKMklCSQibZ2ULv0koCWEH5I5uIYSqWE8kWdF0uEKIO6AY+WrAjBkzCA4OrrP+YklJCWPGjGHAgAGMGTOGy5cvA7ULe8bHx6PVahk6dCjHjx83qlQJJSFs3I07uo35pyFRUVEsW7aszrakpCSCg4NJT08nODiYpKQkALKyssjNzSU9PZ3Zs2cbPce+hJIQdsBUq5k8+OCDtG7dus62jIwMIiJqV22OiIhgx44ddbYrikJAQAClpaUUFhY2eA4ZUxLCDhg7plRcXExUVJThfWxsLLGxsfV+pqioCHd3dwDatm1LUVERAAUFBXh4/GclGA8PDwoKCgz73oqEkhA2T6mz/Ht93NzcSE5Ovv0zKcaf61ak+yaErTPzYpRt2rQxdMsKCwtxc3MDQKPRkJ+fb9gvPz8fjUbT4PEklISwccZeeLvd9k1ISAgpKSkApKSkEBoaWme7Xq/nyJEjtGrVqsGuG0j3TQj7YKIblSZPnsyBAwe4dOkS/fr144UXXuC5554jLi6O9evX4+XlxaJFiwDo378/u3fvRqvV0rx5c+bOnWtcqdawGKUpyWKU5ieLUZrfH1mM8tvs4zRv28moffUlZ2QxSiGE+Rk7XqSGFoqEkhB2wIoefZNQEsIeyCRvQgjVUJCWkhBCZawokySUhLB5MsmbEEJtZJI3IYSqWE8kSSgJYR+sKJUklISwcbJstxBCdaxoSElCSQh7YEWZJKEkhO2784nXGpOEkhC27g4mcLMECSUhbJyV3Ttpf6FUU32doryTli7DpjlaugA74KCv/kP7S0tJxQICAixdghCNTm4JEEKoirSUhBCqIqEkhFANuaNbCKE+1pNJEkpC2AMryiQJJSFsntw8KYRQG2ua5E2W7RZ2yc7WYLUq0lJSIb1ej6IoHD16lEuXLtGpUye8vLxwdJR7pU3hxu8XoKqqCr1eT5MmTSxclfnIaibijimKQmZmJu+88w59+vRh/fr1hIaG8vjjj9v0H09j+G0grVixgqNHj3L27FkmTZrEAw88YLO/X2u6JUC6byp06tQpVq9ezQcffEDv3r35/vvvOXToEKmpqVRVVVm6PKt2I5C2b9/O1q1befHFF4mIiGDZsmXs2bMHgJqaGkuWaBaKYtxLDSSUVKCyspLS0lIACgoK8PLyYtq0afzyyy8sW7aMpKQkXF1dWb16NZ999pmMh9yG7OxsEhMTDe8vXrxInz598PT0ZMSIEQwfPpzZs2dTUFCAg4Pt/VkoRr7UwPZ++1bo0KFDrFmzhs8++4wRI0Zw9epVfH19ycnJITQ0FB8fH/z9/enYsSOBgYFWNWGXWrRq1YqvvvqKJUuWUFRUhKurKxUVFRQWFgIwaNAggoODqaiosHCl5lA7yZsxLzWQMSUV+POf/8zy5cs5ePAg8fHxuLm5AdCxY0f+8Y9/UF1dzbZt23j99dfp3Lmzhau1LkVFRRQVFdG1a1eGDx/ORx99xN69e3njjTfYuXMnq1evxs/Pj8rKSg4fPoyzs7OlSzY5
NXXNjCEtJQv6bTds+PDhhIaGcuDAAXJzc7l+/Tp9+vRh5cqVODo68tprrxEcHGzBaq1TWVkZCQkJTJ06lX379vHWW2+Rm5vLjz/+yJw5c3B2dubrr79mx44dvPvuu7i7u1u6ZLMwZfctKyuLsLAwtFotSUlJJq9VWkoWcuMq0P79+ykpKcHR0ZHExETi4+P597//zUsvvcTBgwcpLy/nhRdesHS5Vqtjx450796ddevWERcXR69evXjooYdYtWoV586dY/z48QBcuXKFFi1aWLhaMzJRS0mn0/HGG2+wfPlyNBoN0dHRhISE0KVLF9OcAAkli1EUhZ07d7JkyRKefvppVqxYQXFxMS+//DKzZs0iISGBffv28fLLL1u6VKv3xBNP0K1bN1asWIG3tzeLFy/m2rVrDBo0iCZNmvDMM8/QvHlzS5dpVqa6JSA7O5sOHTrg7e0NwODBg8nIyJBQslbl5eXodDpat27N9evXSU1N5f3332f//v20aNGCfv36ATBnzhx+/PFHxo0bR+fOnevcWyP+uA4dOtChQwdatWrFggUL0Ov1tGzZkvvvv5+QkBAAm/79NrnLidyTJ4za9+LFi8yaNcvwPjY2ltjYWMP7goICPDw8DO81Gg3Z2dmmKxYJpUZTXl5OYmIiPXv2JDQ0FFdXV5ycnPjwww/Jzs5m3rx5eHp6smPHDtzc3Ojdu7fhs7b8B9OYQkNDueuuu3j77bcN/+7YsaOlyzI7X19fo/f18/Mz/MfRUmSgu5E4Oztz3333cejQIcNNekFBQaxbt47x48fTsWNHDh48SGJiIk5O8t8Kc+nXrx8rVqxg2bJlciXzNmg0GvLz8w3vCwoK0Gg0Jj2H/L+/EdTU1ODg4EBUVBSOjo7s3r0bvV5PUFAQcXFxvP7662i1WrKyspg2bRr33XefpUu2aW3atLF0CVbL39+f3Nxczp49i0ajYevWrcyfP9+k51D0cnuwWd0YD/r111+55557gNpHHNLS0ujfvz9hYWEcO3YMvV5P06ZN6dmzp4whCVXbvXs3c+fORafTMXz4cJ5//nmTHl9CqRHcuEmva9eu3H///QwYMICMjAzS0tIIDAxEq9Xi4uJi6TKFUAXpvpnZ/v37WbhwIUuWLMUEN9kAAAU+SURBVCExMZHvvvuOCxcuMGrUKKqrq/niiy94+OGHJZSE+D8SSmbw2+5XTk4OCxcu5PTp05w7d47o6GgyMjLQ6/WMHDmSPn36yBiHEL8hoWQGiqJw8OBBCgsL8fb2xtnZmV27drFkyRLatWtHRkYGP/zwAxcuXKB9+/aWLlcIVZFQMqEbLaTDhw/zyiuv4O/vj6OjI2VlZXz//ff06NGDgIAArly5wujRoyWQhPgdMtBtYtnZ2bz99ttMmTKFgIAAzp49y65du/jmm284e/Ysd911F2PHjmXAgAGWLlUIVZKWkomVlZVx8OBB9u3bR0BAAB4eHnh6etKxY0cSEhK4du0abdq0kcv+QtyC3NFtYn379mXJkiVs2LCB1NRU7rrrLlxcXPjyyy+prKw0DGpLIAnx+6SlZAZ/+ctfcHBw4MUXXyQ9PR1FUZgwYYJh8jYhxK1JS8lMQkJCePvttzlz5gz+/v6Ehoai1+tlfm0hGiAtJTMKDQ2ladOmzJw5k3vvvVcGt4Uwglx9awR79+7l3nvvNUyMJYS4NQklIYSqyJiSEEJVJJSEEKoioSSEUBUJJTvk5+fHsGHDGDJkCBMnTuTq1au3faz9+/czbty4eveZPn06X3zxhVHHO3r0KPHx8bddj7B+Ekp2qFmzZmzatMlwx/knn3xS5/t6vZ6amhqL1Obv7y/LStk5CSU716dPH86cOUNeXh5hYWG89NJLDBkyhAsXLvDll18SGxtLZGQkEydOpKKiAqhdIXXgwIFERkayfft2oHYe8gEDBlBcXGx4r9VqDe9vWLRoEdOnT0en05Gdnc0TTzxBeHg40dHRlJeX12l5ZWdnExsbS0RE
BE888QQ5OTmN+JsRliKhZMeqq6vJysqia9euAJw5c4YRI0awdetWmjdvzr/+9S+WL1/Oxo0b6dmzJ8uXL6eyspJXXnmF9957j+TkZC5evAiAg4MD4eHhbN68GYCvvvqK7t2713m05s0336S4uJiEhAR0Oh2TJk1i5syZbN68mRUrVtCsWbM69fn4+LBmzRpSUlKYOHEiCxcubKTfjLAkuaPbDl27do1hw4YBtS2l6OhoCgsL8fLyIiAgAIDvvvuOkydP8uSTTwJQVVVFQEAAOTk5tG/f3rBeWnh4OJ9++ikAw4cP529/+xujR49mw4YNREVFGc65dOlSevXqxezZswE4ffo0bdu2Nazc4uzsfFOdZWVlTJs2jTNnzqAoClVVVeb5hQhVkVCyQzfGlP5bixYtDF/r9Xr69u3LggUL6uxz4sStV1r19PSkTZs2fP3112RnZ5OYmGj4nr+/P8ePH6ekpIS7777bqDrfeecdAgMD+ec//0leXh7/+7//a9TnhHWT7pv4XQEBARw+fJgzZ84AcOXKFU6fPo2Pjw/nzp3jl19+AWDr1q11PhcTE8PUqVMZOHAgjo6Ohu2PPPIIf/3rXxk3bhzl5eV06tSJixcvGpZ8Li8vp7q6us6xysrKDAsdbty40Ww/q1AXCSXxu9zc3EhISGDy5MkMHTqU2NhYcnJyaNq0KW+88QbPPfcckZGRN03HEhISwpUrV+p03W54/PHHiYmJ4fnnn6empoaFCxcSHx9PeHg4zzzzDJWVlcB/5poaO3YsCxYsICIi4qbAErZLnn0TJnX06FESEhJYu3btbX0+LS2NzMxM3nzzTRNXJqyFtJSEySQlJTFx4kQmT558W5/PyMhg4cKFxMbGmrgyYU2kpSSEUBVpKQkhVEVCSQihKhJKQghVkVASQqiKhJIQQlUklIQQqvL/AdMdlw6fGiFPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x288 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def visualize_confusion_matrix(y_true, y_pred, classes,\n",
    "                          title=None):\n",
    "    \"\"\"Print the confusion matrix and draw it as an annotated heatmap.\n",
    "\n",
    "    Args:\n",
    "        y_true: ground-truth labels.\n",
    "        y_pred: predicted labels.\n",
    "        classes: tick labels used on both axes.\n",
    "        title: optional plot title.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Axes holding the heatmap.\n",
    "    \"\"\"\n",
    "    sns.set_style(\"whitegrid\", {'axes.grid' : False})\n",
    "    matrix = confusion_matrix(y_true, y_pred)\n",
    "    print(matrix)\n",
    "    n_rows, n_cols = matrix.shape\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(4,4))\n",
    "    heatmap = ax.imshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)\n",
    "    ax.figure.colorbar(heatmap, ax=ax)\n",
    "\n",
    "    # One tick per matrix row/column, labelled with the class names.\n",
    "    ax.set(xticks=np.arange(n_cols),\n",
    "           yticks=np.arange(n_rows),\n",
    "           xticklabels=classes, yticklabels=classes,\n",
    "           title=title,\n",
    "           ylabel='Etykieta',\n",
    "           xlabel='Predykcja')\n",
    "\n",
    "    # Slant the x tick labels.\n",
    "    plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n",
    "             rotation_mode=\"anchor\")\n",
    "\n",
    "    # Write every count into its cell; use white text where the count\n",
    "    # exceeds half of the largest count, black text elsewhere.\n",
    "    half_max = matrix.max() / 2.\n",
    "    for (row, col), count in np.ndenumerate(matrix):\n",
    "        text_color = \"white\" if count > half_max else \"black\"\n",
    "        ax.text(col, row, format(count, 'd'),\n",
    "                ha=\"center\", va=\"center\",\n",
    "                color=text_color)\n",
    "\n",
    "    # Bug workaround kept from the original: set explicit, inverted y limits.\n",
    "    ax.set_ylim(len(classes)-0.5, -0.5)\n",
    "    return ax\n",
    "\n",
    "#visualize_confusion_matrix(y_test, y_pred, [\"paleta\", \"tło\"], title=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Color features - NaiveBayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pallet_pix_counts=[41068636, 2327190, 22568602]\n",
    "# pallet_pix_count_map=dict(zip(COLORS, pallet_pix_counts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def readnsplit(color):\n",
    "#     #-------------y-------------------------\n",
    "#     pallet_pix_count=pallet_pix_count_map[color]\n",
    "#     background_pix_count=pallet_pix_count*2\n",
    "    \n",
    "#     pallet_y=np.ones(pallet_pix_count, dtype=bool)\n",
    "#     background_y=np.zeros(background_pix_count, dtype=bool)\n",
    "\n",
    "#     y=np.hstack((pallet_y,background_y))\n",
    "\n",
    "#     pallet_y=None\n",
    "#     background_y=None\n",
    "\n",
    "#     sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=0)\n",
    "#     sss.get_n_splits(y, y)\n",
    "#     train_index, test_index = list(sss.split(y, y))[0]\n",
    "#     sss=None\n",
    "\n",
    "#     y_train, y_test = y[train_index], y[test_index]\n",
    "#     y=None\n",
    "\n",
    "#     #-------------X-------------------------\n",
    "#     background_X=np.load(PROJECT_PATH+\"backgrounds_color.npy\", allow_pickle=True)\n",
    "#     background_X=np.array([np.hstack(x) for x in background_X])\n",
    "#     background_X=np.concatenate(background_X, 0)\n",
    "#     indices=range(len(background_X))\n",
    "#     indices=random.sample(indices, background_pix_count)\n",
    "#     background_X=background_X[indices]\n",
    "    \n",
    "#     pallet_X=np.load(PROJECT_PATH+\"pallets_color_{}.npy\".format(color), allow_pickle=True)\n",
    "#     pallet_X=np.array([np.hstack(x) for x in pallet_X])\n",
    "#     pallet_X=np.concatenate(pallet_X, 0)\n",
    "\n",
    "#     X=np.vstack((pallet_X,background_X))\n",
    "\n",
    "#     pallet_X=None\n",
    "#     background_X=None\n",
    "\n",
    "#     X_train, X_test = X[train_index], X[test_index]\n",
    "#     X=None\n",
    "#     return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_color_data(train_scenes, test_scenes, color):\n",
    "    \"\"\"Build train/test matrices of colour samples for one pallet colour.\n",
    "\n",
    "    Pallet samples (feature pallets_color, label 1) are loaded for the given\n",
    "    color; background samples (feature backgrounds_color, label 0) are loaded\n",
    "    without a colour. Files are located with\n",
    "    get_existing_clf_ds_filepath(feature, scene, color).\n",
    "\n",
    "    Args:\n",
    "        train_scenes: iterable of scenes used for the training split.\n",
    "        test_scenes: iterable of scenes used for the test split.\n",
    "        color: pallet colour whose samples form the positive class.\n",
    "\n",
    "    Returns:\n",
    "        X_train, X_test, y_train, y_test. In every output the pallet rows\n",
    "        come first, followed by the background rows.\n",
    "    \"\"\"\n",
    "    def get_samples(feature, scenes, y_val, color=None):\n",
    "        # NOTE: allow_pickle=True unpickles object arrays; only load trusted files.\n",
    "        per_scene = [np.load(get_existing_clf_ds_filepath(feature, s, color), allow_pickle=True) for s in scenes]\n",
    "        # Skip scenes that contributed no samples.\n",
    "        per_scene = [x for x in per_scene if len(x) > 0]\n",
    "        items = np.concatenate(per_scene)\n",
    "        # Concatenate the hstack-ed items directly instead of wrapping them in\n",
    "        # np.array first: the result is the same, and items of different sizes\n",
    "        # no longer hit the ragged-array ValueError raised by NumPy >= 1.24.\n",
    "        X = np.concatenate([np.hstack(item) for item in items], 0)\n",
    "        y = np.full(len(X), y_val)\n",
    "        return X, y\n",
    "\n",
    "    train_pallet_X, train_pallet_y = get_samples(\"pallets_color\", train_scenes, 1, color)\n",
    "    train_background_X, train_background_y = get_samples(\"backgrounds_color\", train_scenes, 0)\n",
    "\n",
    "    test_pallet_X, test_pallet_y = get_samples(\"pallets_color\", test_scenes, 1, color)\n",
    "    test_background_X, test_background_y = get_samples(\"backgrounds_color\", test_scenes, 0)\n",
    "\n",
    "    X_train = np.vstack((train_pallet_X, train_background_X))\n",
    "    X_test = np.vstack((test_pallet_X, test_background_X))\n",
    "    y_train = np.hstack((train_pallet_y, train_background_y))\n",
    "    y_test = np.hstack((test_pallet_y, test_background_y))\n",
    "\n",
    "    return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_color(color, train_scenes, test_scenes, fold_name, persist=False):\n",
    "    \"\"\"Train a Gaussian naive Bayes classifier on the samples of one pallet colour.\n",
    "\n",
    "    Args:\n",
    "        color: pallet colour whose samples form the positive class.\n",
    "        train_scenes: scenes whose samples are used for fitting.\n",
    "        test_scenes: scenes whose samples are used for evaluation.\n",
    "        fold_name: suffix used in the file name of the persisted model.\n",
    "        persist: when true, dump the fitted model to\n",
    "            naive_bayes_clf_<color>_<fold_name>.joblib (path relative to the\n",
    "            working directory). Defaults to False, like classify_gradient.\n",
    "\n",
    "    Returns:\n",
    "        (acc, prec, rec, f1, clf): accuracy, precision, recall and F1 on the\n",
    "        test split, plus the fitted GaussianNB.\n",
    "    \"\"\"\n",
    "    X_train, X_test, y_train, y_test = get_color_data(train_scenes, test_scenes, color)\n",
    "    clf = GaussianNB()\n",
    "    clf.fit(X_train, y_train)\n",
    "    # Drop references to arrays that are no longer needed so they can be\n",
    "    # garbage-collected before the next step.\n",
    "    X_train = None\n",
    "    y_train = None\n",
    "    y_pred = clf.predict(X_test)\n",
    "    X_test = None\n",
    "\n",
    "    acc = accuracy_score(y_test, y_pred)\n",
    "    prec = precision_score(y_test, y_pred)\n",
    "    rec = recall_score(y_test, y_pred)\n",
    "    f1 = f1_score(y_test, y_pred)\n",
    "\n",
    "    if persist:\n",
    "        dump(clf, 'naive_bayes_clf_{}_{}.joblib'.format(color, fold_name))\n",
    "    return acc, prec, rec, f1, clf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "blue\n",
      "acc:  0.9162741484906277  f1:  0.8451655212794522  prec:  0.7801153604387941  rec:  0.9220510320372187\n",
      "[[2212798  187067]\n",
      " [ 623703 6660060]]\n",
      "dark\n",
      "acc:  0.9203908959244943  f1:  0.6519227719029129  prec:  0.558045092505196  rec:  0.783773899846273\n",
      "[[ 600091  165552]\n",
      " [ 475254 6808509]]\n",
      "wooden\n",
      "acc:  0.9503026851177495  f1:  0.7017684616163444  prec:  0.5988835548201028  rec:  0.8473361623623071\n",
      "[[ 457459   82420]\n",
      " [ 306394 6977369]]\n"
     ]
    }
   ],
   "source": [
    "def classify_colors(train_scenes, test_scenes, fold_name, persist=False):\n",
    "    \"\"\"Run classify_color for every colour in COLORS and collect the results.\n",
    "\n",
    "    Args:\n",
    "        train_scenes: scenes whose samples are used for fitting.\n",
    "        test_scenes: scenes whose samples are used for evaluation.\n",
    "        fold_name: suffix used in the file names of persisted models.\n",
    "        persist: forwarded to classify_color; when true each model is dumped to disk.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each colour to the (acc, prec, rec, f1, clf) tuple\n",
    "        returned by classify_color, in COLORS order. Previously the results\n",
    "        were discarded and None was returned.\n",
    "    \"\"\"\n",
    "    results = {}\n",
    "    for color in COLORS:\n",
    "        results[color] = classify_color(color, train_scenes, test_scenes, fold_name, persist)\n",
    "    return results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
